import numpy as np
from math import sqrt
from collections import Counter


class KnnClassifier:
    """k-nearest-neighbours classifier (Euclidean distance, majority vote).

    ``fit`` only stores the training data; all distance work happens at
    prediction time. Inputs are expected to expose ``.shape`` and support
    NumPy broadcasting (i.e. behave like ``numpy.ndarray``).
    """

    def __init__(self, k):
        """Initialise the kNN classifier.

        Args:
            k: Number of nearest neighbours that take part in the vote.

        Raises:
            AssertionError: If ``k`` is smaller than 1.
        """
        assert k >= 1, "k must be valid"
        self.k = k
        self._x_train = None
        self._y_train = None

    def fit(self, x_train, y_train):
        """Train the classifier on the data set ``x_train`` / ``y_train``.

        The arrays are stored by reference, not copied.

        Args:
            x_train: 2-D array of samples, one row per sample.
            y_train: Labels; its first dimension must match ``x_train``'s.

        Returns:
            This classifier, so calls can be chained.

        Raises:
            AssertionError: If the sample counts differ, or there are fewer
                training samples than ``k``.
        """
        assert x_train.shape[0] == y_train.shape[0], "the size of x_trans must equal to the size of y_train "
        assert self.k <= x_train.shape[0], "the size of x_train must be at least k."

        self._x_train = x_train
        self._y_train = y_train
        return self

    def predict(self, x_predict):
        """Predict a label for every row of ``x_predict``.

        Args:
            x_predict: 2-D array with the same number of columns as the
                training samples.

        Returns:
            ``numpy.ndarray`` holding one predicted label per input row.

        Raises:
            AssertionError: If ``fit`` has not been called, or the feature
                counts do not match.
        """
        assert self._x_train is not None and self._y_train is not None, "must fit before predict!"
        assert x_predict.shape[1] == self._x_train.shape[1], "the feature number of x_predict must be equal to x_train"

        y_predict = [self._predict(x) for x in x_predict]
        return np.array(y_predict)

    def _predict(self, x):
        """Predict the label of the single sample ``x`` (a 1-D feature vector).

        Raises:
            AssertionError: If ``x`` does not have as many features as the
                training samples.
        """
        assert x.shape[0] == self._x_train.shape[1], "the feature number of x must be equal to x_train"

        return self._vote(self.k, self._x_train, self._y_train, x)

    def __repr__(self):
        # %-formatting is kept on purpose: "%d" also accepts a float k.
        return "KNN(k=%d)" % self.k

    @staticmethod
    def _vote(k, x_train, y_train, point):
        """Return the most common label among the ``k`` rows nearest to ``point``.

        Shared by ``_predict`` and ``knn_classify`` so both rank neighbours
        and break ties in exactly the same way.
        """
        # One vectorised pass: Euclidean distance from `point` to every row.
        distances = np.sqrt(np.sum((x_train - point) ** 2, axis=1))
        nearest = np.argsort(distances)
        # Votes are counted nearest-first; Counter.most_common lists equal
        # counts in first-seen order, so a tie goes to the label whose first
        # vote came from the closer neighbour.
        votes = Counter(y_train[i] for i in nearest[:k])

        return votes.most_common(1)[0][0]

    @staticmethod
    def knn_classify(k, x_train, y_train, new_point):
        """Classify ``new_point`` in one call, without creating a classifier.

        Declared as a static method so it works both as
        ``KnnClassifier.knn_classify(...)`` and through an instance.

        Args:
            k: Number of nearest neighbours that take part in the vote.
            x_train: 2-D array of samples, one row per sample.
            y_train: Labels; its first dimension must match ``x_train``'s.
            new_point: 1-D feature vector to classify.

        Returns:
            The most common label among the ``k`` nearest training samples.

        Raises:
            AssertionError: If ``k`` is out of range or the shapes are
                inconsistent.
        """
        assert 1 <= k <= x_train.shape[0], "k must be valid"
        assert x_train.shape[0] == y_train.shape[0], "the size of X_trans must equal to the size of Y_train "
        assert x_train.shape[1] == new_point.shape[0], "the feature number of x must be equal to X_train"

        return KnnClassifier._vote(k, x_train, y_train, new_point)
